{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from network import Network\n",
    "from conv import Conv\n",
    "from dense import Dense\n",
    "from activation import Relu\n",
    "from metrics import log_loss, log_loss_grad, softmax\n",
    "import math\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "FILE = \"fashion-mnist_train.csv\"\n",
    "with open(f\"../data/{FILE}\") as f:\n",
    "    # Everything after the first line is kept, one raw CSV row per example.\n",
    "    examples = f.read().strip().split(\"\\n\")[1:]\n",
    "\n",
    "# One random.random() draw per kept example decides its split:\n",
    "# draw > test_cutoff -> test, draw > valid_cutoff -> valid, otherwise train.\n",
    "valid_cutoff = .8\n",
    "test_cutoff = .9\n",
    "train, valid, test = [], [], []\n",
    "\n",
    "for line in examples:\n",
    "    raw_label, raw_pixels = line.split(\",\", 1)\n",
    "    label = int(raw_label)\n",
    "    # Keep only labels 0-4; skipped rows do not consume a random draw.\n",
    "    if label <= 4:\n",
    "        data = np.asarray([int(p) for p in raw_pixels.split(\",\")], dtype=\"int32\")\n",
    "        data = data.reshape((1, 28, 28))\n",
    "        # Scale values from -1 to 1\n",
    "        data = ((data / 255) - .5) / .5\n",
    "        # One-hot target of shape (1, 5).\n",
    "        target = np.zeros((1, 5))\n",
    "        target[0, label] = 1\n",
    "\n",
    "        draw = random.random()\n",
    "        if draw > test_cutoff:\n",
    "            bucket = test\n",
    "        elif draw > valid_cutoff:\n",
    "            bucket = valid\n",
    "        else:\n",
    "            bucket = train\n",
    "        bucket.append((data, target))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "class TestNet(Network):\n",
    "    \"\"\"Small network: two 3x3 Conv layers followed by a 5-output Dense layer.\n",
    "\n",
    "    The Dense input size of 24 ** 2 matches a 28x28 single-channel image\n",
    "    whose side shrinks by 2 at each 3x3 convolution (28 -> 26 -> 24) --\n",
    "    assumes Conv applies no padding; TODO confirm against conv.py.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Layers are created in forward order before Network.__init__ runs.\n",
    "        self.conv1 = Conv(input_channels=1, output_channels=3, kernel_x=3, kernel_y=3)\n",
    "        self.conv2 = Conv(input_channels=3, output_channels=1, kernel_x=3, kernel_y=3)\n",
    "        self.dense = Dense(input_size=24 ** 2, output_size=5, activation=False)\n",
    "        self.last_input = None\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run x through conv1, conv2 and dense; return the dense output.\"\"\"\n",
    "        feature_map = self.conv2.forward(self.conv1.forward(x))\n",
    "        # Flatten everything into a single row for the dense layer.\n",
    "        flat = feature_map.reshape(1, math.prod(feature_map.shape))\n",
    "        return self.dense.forward(flat)\n",
    "\n",
    "    def backward(self, grad, lr):\n",
    "        \"\"\"Pass grad back through dense, conv2, conv1 (in that order) with rate lr.\n",
    "\n",
    "        Each layer's backward(grad, lr) result is fed to the next layer;\n",
    "        nothing is returned.\n",
    "        \"\"\"\n",
    "        for layer in (self.dense, self.conv2, self.conv1):\n",
    "            grad = layer.backward(grad, lr)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 loss: 0.9882971977021661\n",
      "Epoch 0 valid accuracy: 0.7571810287241149\n",
      "Epoch 1 loss: 0.671348362336631\n",
      "Epoch 1 valid accuracy: 0.7852371409485638\n",
      "Epoch 2 loss: 0.6007674499802047\n",
      "Epoch 2 valid accuracy: 0.7969271877087508\n",
      "Epoch 3 loss: 0.5640594744083588\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[4], line 19\u001B[0m\n\u001B[1;32m     17\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i, img \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(valid):\n\u001B[1;32m     18\u001B[0m     image, target \u001B[38;5;241m=\u001B[39m img\n\u001B[0;32m---> 19\u001B[0m     valid_pred \u001B[38;5;241m=\u001B[39m \u001B[43mnet\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mforward\u001B[49m\u001B[43m(\u001B[49m\u001B[43mimage\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     20\u001B[0m     \u001B[38;5;28;01mmatch\u001B[39;00m[i] \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39margmax(valid_pred) \u001B[38;5;241m==\u001B[39m np\u001B[38;5;241m.\u001B[39margmax(target)\n\u001B[1;32m     22\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mEpoch \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mepoch\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m valid accuracy: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mmatch\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m/\u001B[39m match\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n",
      "Cell \u001B[0;32mIn[3], line 12\u001B[0m, in \u001B[0;36mTestNet.forward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, x):\n\u001B[1;32m     11\u001B[0m     x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mconv1\u001B[38;5;241m.\u001B[39mforward(x)\n\u001B[0;32m---> 12\u001B[0m     x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mconv2\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mforward\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     13\u001B[0m     x \u001B[38;5;241m=\u001B[39m x\u001B[38;5;241m.\u001B[39mreshape(\u001B[38;5;241m1\u001B[39m, math\u001B[38;5;241m.\u001B[39mprod(x\u001B[38;5;241m.\u001B[39mshape))\n\u001B[1;32m     14\u001B[0m     x \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mdense\u001B[38;5;241m.\u001B[39mforward(x)\n",
      "File \u001B[0;32m~/Personal/nnets/nnets/conv.py:41\u001B[0m, in \u001B[0;36mConv.forward\u001B[0;34m(self, x)\u001B[0m\n\u001B[1;32m     39\u001B[0m output \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mzeros((\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moutput_channels, new_x, new_y))\n\u001B[1;32m     40\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m channel \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39minput_channels):\n\u001B[0;32m---> 41\u001B[0m     unrolled \u001B[38;5;241m=\u001B[39m \u001B[43munroll_image\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m[\u001B[49m\u001B[43mchannel\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m:\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mkernel_x\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mkernel_y\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     42\u001B[0m     \u001B[38;5;28;01mfor\u001B[39;00m next_channel \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moutput_channels):\n\u001B[1;32m     43\u001B[0m         kernel \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mweights[channel, next_channel, :]\n",
      "File \u001B[0;32m~/Personal/nnets/nnets/conv.py:11\u001B[0m, in \u001B[0;36munroll_image\u001B[0;34m(image, kernel_x, kernel_y)\u001B[0m\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21munroll_image\u001B[39m(image, kernel_x, kernel_y):\n\u001B[0;32m---> 11\u001B[0m     unrolled \u001B[38;5;241m=\u001B[39m \u001B[43mview_as_windows\u001B[49m\u001B[43m(\u001B[49m\u001B[43mimage\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m(\u001B[49m\u001B[43mkernel_x\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkernel_y\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     12\u001B[0m     x_size \u001B[38;5;241m=\u001B[39m (image\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m] \u001B[38;5;241m-\u001B[39m (kernel_x \u001B[38;5;241m-\u001B[39m \u001B[38;5;241m1\u001B[39m))\n\u001B[1;32m     13\u001B[0m     y_size \u001B[38;5;241m=\u001B[39m (image\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m1\u001B[39m] \u001B[38;5;241m-\u001B[39m (kernel_y \u001B[38;5;241m-\u001B[39m \u001B[38;5;241m1\u001B[39m))\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/skimage/util/shape.py:228\u001B[0m, in \u001B[0;36mview_as_windows\u001B[0;34m(arr_in, window_shape, step)\u001B[0m\n\u001B[1;32m    225\u001B[0m arr_shape \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(arr_in\u001B[38;5;241m.\u001B[39mshape)\n\u001B[1;32m    226\u001B[0m window_shape \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(window_shape, dtype\u001B[38;5;241m=\u001B[39marr_shape\u001B[38;5;241m.\u001B[39mdtype)\n\u001B[0;32m--> 228\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[43m(\u001B[49m\u001B[43m(\u001B[49m\u001B[43marr_shape\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mwindow_shape\u001B[49m\u001B[43m)\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m<\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43many\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[1;32m    229\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m`window_shape` is too large\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m    231\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m ((window_shape \u001B[38;5;241m-\u001B[39m \u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m<\u001B[39m \u001B[38;5;241m0\u001B[39m)\u001B[38;5;241m.\u001B[39many():\n",
      "File \u001B[0;32m~/.virtualenvs/nnets/lib/python3.10/site-packages/numpy/core/_methods.py:58\u001B[0m, in \u001B[0;36m_any\u001B[0;34m(a, axis, dtype, out, keepdims, where)\u001B[0m\n\u001B[1;32m     55\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_any\u001B[39m(a, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, out\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, keepdims\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m, \u001B[38;5;241m*\u001B[39m, where\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m):\n\u001B[1;32m     56\u001B[0m     \u001B[38;5;66;03m# Parsing keyword arguments is currently fairly slow, so avoid it for now\u001B[39;00m\n\u001B[1;32m     57\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m where \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m:\n\u001B[0;32m---> 58\u001B[0m         \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mumr_any\u001B[49m\u001B[43m(\u001B[49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maxis\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdtype\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mout\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkeepdims\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     59\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m umr_any(a, axis, dtype, out, keepdims, where\u001B[38;5;241m=\u001B[39mwhere)\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "net = TestNet()\n",
    "lr = 5e-4\n",
    "epochs = 20\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Training pass: one forward/backward step per (image, target) example.\n",
    "    epoch_loss = 0\n",
    "    for image, target in train:\n",
    "        logits = net.forward(image)\n",
    "        # Run softmax once and share the result between the loss and its gradient\n",
    "        # (assumes log_loss does not modify its input in place -- TODO confirm).\n",
    "        probs = softmax(logits)\n",
    "\n",
    "        epoch_loss += log_loss(probs, target)\n",
    "        net.backward(log_loss_grad(probs, target), lr)\n",
    "\n",
    "    print(f\"Epoch {epoch} loss: {epoch_loss / len(train)}\")\n",
    "\n",
    "    # Validation pass: forward only; a prediction counts as correct when the\n",
    "    # argmax of the network output equals the argmax of the target.\n",
    "    # Named `correct` rather than `match`, which is a soft keyword since Python 3.10.\n",
    "    correct = np.zeros(len(valid))\n",
    "    for i, (image, target) in enumerate(valid):\n",
    "        valid_pred = net.forward(image)\n",
    "        correct[i] = np.argmax(valid_pred) == np.argmax(target)\n",
    "\n",
    "    print(f\"Epoch {epoch} valid accuracy: {correct.sum() / correct.shape[0]}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
